import numpy as np
import pandas as pd

from utils.transform_utils import *


# Marker string stripped from every text field extracted from the OpenAI
# summarization datasets.
_END_OF_TEXT = "<|endoftext|>"

# Rating axes copied out of each ``summary["axes"]`` dict, in output-column order.
_SCORE_AXES = ("overall", "accuracy", "coverage", "coherence")

# Robustness sources: (name substring, text column, summary column, whether to
# drop rows containing an empty string). When several substrings match, the
# LAST matching entry wins, mirroring the original sequence of independent
# ``if`` checks that each overwrote ``df``.
_ROBUSTNESS_SOURCES = (
    ("scientific_papers_robustness", "article", "abstract", True),
    ("cnn_dailymail", "article", "highlights", False),
    ("billsum", "text", "summary", True),
)

# Sensitivity sources whose single text column becomes ``original``; the first
# matching substring wins.
_SENSITIVITY_COLUMNS = (
    ("scientific_papers_sensitivity", "abstract"),
    ("paul_graham", "text"),
    ("amazon_polarity", "content"),
    ("arguana", "text"),
)


def _strip_end_of_text(text):
    """Return *text* with every ``<|endoftext|>`` marker removed."""
    return text.replace(_END_OF_TEXT, "")


def _flatten(groups):
    """Concatenate an iterable of iterables into one flat list (empty-safe)."""
    return [item for group in groups for item in group]


def _prepare_summarize_scores(dataset):
    """Build one row per rated summary with its per-axis scores."""
    # NOTE(review): the source text comes from ``info["title"]`` here but from
    # ``info["post"]`` in the comparisons branch -- confirm that the title
    # (rather than the post body) is the intended ``original``.
    posts = dataset["info"].apply(lambda info: _strip_end_of_text(info["title"]))
    summaries = dataset["summary"].apply(
        lambda summary: _strip_end_of_text(summary["text"])
    )
    axes = dataset["summary"].apply(lambda summary: summary["axes"])

    frame = {"original": posts, "summary": summaries}
    for axis in _SCORE_AXES:
        # Bind ``axis`` as a default so the lambda does not depend on the loop
        # variable's late binding.
        frame[axis] = axes.apply(lambda scores, axis=axis: scores[axis])
    return pd.DataFrame(frame)


def _prepare_summarize_comparisons(dataset):
    """Build one row per comparison: post, both candidate summaries, choice."""
    candidates = dataset["summaries"]
    return pd.DataFrame(
        {
            "original": dataset["info"].apply(
                lambda info: _strip_end_of_text(info["post"])
            ),
            "summary1": candidates.apply(
                lambda pair: _strip_end_of_text(pair[0]["text"])
            ),
            "summary2": candidates.apply(
                lambda pair: _strip_end_of_text(pair[1]["text"])
            ),
            "choice": dataset["choice"],
        }
    )


def _add_perturbations(df):
    """Add perturbed copies of ``summary`` and ``original`` to *df* in place."""
    # Order matters: for each transform, ``summary`` is done before
    # ``original``. This reproduces the original column order and call order,
    # which is relevant if a transform draws random numbers.
    transforms = (
        ("sentence_shuffled", shuffle_text),
        ("word_shuffled", shuffle_words),
        ("negated", negate_text),
        # 10 is the hard-coded second argument to ``prune_text``; its meaning
        # is defined in ``utils.transform_utils``.
        ("pruned", lambda texts: prune_text(texts, 10)),
        ("random_upper", capitalize_random),
        ("numerized", numerize_text),
    )
    for suffix, transform in transforms:
        for source in ("summary", "original"):
            df[f"{source}_{suffix}"] = transform(df[source])


def _prepare_robustness(name, dataset):
    """Return the perturbed robustness frame for *name*, or ``None`` if none matches."""
    matches = [spec for spec in _ROBUSTNESS_SOURCES if spec[0] in name]
    if not matches:
        return None
    _, text_column, summary_column, drop_empty = matches[-1]

    df = pd.DataFrame(
        {"original": dataset[text_column], "summary": dataset[summary_column]}
    )
    if drop_empty:
        # Drop rows where either field is exactly "" and renumber from 0.
        df = df[~(df == "").any(axis=1)].reset_index(drop=True)
    _add_perturbations(df)
    return df


def _prepare_sensitivity(name, dataset):
    """Return a single-column ``original`` frame for *name*, or ``None``."""
    for key, column in _SENSITIVITY_COLUMNS:
        if key in name:
            return pd.DataFrame({"original": dataset[column]})
    if "reddit" in name and "clustering" not in name:
        # ``sentences`` holds one group of sentences per row; flatten them.
        return pd.DataFrame({"original": _flatten(dataset["sentences"])})
    return None


def _prepare_clustering(dataset):
    """Flatten per-row sentence groups and their labels into one row per sentence."""
    return pd.DataFrame(
        {
            "original": _flatten(dataset["sentences"]),
            "labels": _flatten(dataset["labels"]),
        }
    )


def prepare_dataset(name, dataset):
    """Convert a raw benchmark dataset into a DataFrame for its task family.

    The family is chosen by substring match on ``name``. Families are checked
    in this order, and the first one that matches wins:

    * preference (``openai_summarize_scores`` / ``openai_summarize_comparisons``);
    * robustness (``scientific_papers_robustness`` / ``cnn_dailymail`` /
      ``billsum``) -- ``original``/``summary`` plus perturbed copies of each;
    * sensitivity (``scientific_papers_sensitivity``, ``paul_graham``,
      ``amazon_polarity``, ``arguana``, or ``reddit`` without ``clustering``)
      -- a single ``original`` column;
    * clustering (any name containing ``clustering``) -- flattened
      ``original`` sentences with their ``labels``.

    Args:
        name: Dataset identifier; only substring membership is checked.
        dataset: Column-indexable source table. The preference branches call
            ``.apply`` on its columns, so they require a pandas DataFrame
            (assumed for the other branches too -- confirm with callers).

    Returns:
        A new ``pd.DataFrame`` for a recognised name; otherwise ``dataset``
        is returned unchanged.

    Raises:
        ValueError: from pandas, if the flattened clustering sentences and
            labels differ in length.
    """
    # Preference Datasets
    if "openai_summarize_scores" in name:
        return _prepare_summarize_scores(dataset)
    if "openai_summarize_comparisons" in name:
        return _prepare_summarize_comparisons(dataset)

    # Robustness Datasets
    robustness = _prepare_robustness(name, dataset)
    if robustness is not None:
        return robustness

    # Sensitivity Datasets
    sensitivity = _prepare_sensitivity(name, dataset)
    if sensitivity is not None:
        return sensitivity

    # Clustering Datasets
    if "clustering" in name:
        return _prepare_clustering(dataset)
    return dataset
